#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>
#include <linux/if_ether.h>
#include <linux/ip.h>
#include <linux/icmp.h>
#include <linux/tcp.h>
#include <linux/in.h>
 
#define MAX_CHECKING 4
/* Upper bound on the 16-bit words added by sum16(): 750 words = 1500 bytes. */
#define MAX_CSUM_WORDS 750

/*
 * Accumulate the unfolded 16-bit one's-complement sum of a byte range.
 *
 * Words are loaded with native __u16 reads and added without any byte
 * swapping, so the result is in the same byte-order domain as the data in
 * memory; the caller folds it to 16 bits and stores it directly.
 *
 * @data:     first byte to sum.
 * @size:     number of bytes to sum. At most MAX_CSUM_WORDS full words are
 *            added; larger sizes are not fully covered.
 * @data_end: one past the last readable byte. Summation stops as soon as a
 *            read would cross it.
 *
 * Returns the 32-bit running sum (not folded, not complemented).
 */
static __always_inline __u32 sum16(const void* data, __u32 size, const void* data_end) {
    __u32 sum = 0;
    const __u16 *ptr = (const __u16 *)data;

    #pragma unroll
    for (int i = 0; i < MAX_CSUM_WORDS; ++i) {
        /* Next full word would run past the requested size: done with words. */
        if ((const void *)(ptr + 1) > (data + size)) {
            break;
        }

        /* Next full word would run past the readable data: give up here. */
        if ((const void *)(ptr + 1) > data_end) {
            return sum;
        }

        sum += *ptr;
        ptr++;
    }

    /* Handle the trailing byte when size is odd. */
    if (size & 1) {
        const __u8 *byte_ptr = (const __u8 *)ptr; /* just after the last full word */

        /* Only read the byte if it lies inside the readable data. */
        if ((const void *)(byte_ptr + 1) <= data_end) {
            /*
             * The odd byte is padded with a zero byte *after* it in memory
             * and then read like every other word. Building the pair through
             * a union keeps this consistent with the native loads above on
             * both little- and big-endian targets (a plain "<< 8" is only
             * correct on big-endian).
             */
            union {
                __u16 word;
                __u8 bytes[2];
            } tail = { .bytes = { *byte_ptr, 0 } };

            sum += tail.word;
        }
        /* If the bounds check fails, return the sum calculated so far. */
    }

    return sum;
}
 
 
/*
 * XDP program: turn an incoming IPv4 TCP SYN into a SYN+ACK in place and
 * transmit it back out of the receiving interface.
 *
 * Verdicts:
 *   XDP_PASS - frame is not IPv4/TCP, or its headers are truncated or
 *              inconsistent (nothing has been modified at that point).
 *   XDP_DROP - TCP segment that is not a pure SYN (SYN clear, or ACK set).
 *   XDP_TX   - SYN rewritten into a SYN+ACK: MAC addresses, IP addresses and
 *              ports swapped, ACK flag set, ack_seq = seq + 1, TCP checksum
 *              recomputed.
 */
SEC("xdp")
int tcp_bounce(struct xdp_md *ctx) {
    void *data = (void *)(long)ctx->data;
    void *data_end = (void *)(long)ctx->data_end;

    struct ethhdr *eth = data;
    if ((void *)eth + sizeof(*eth) > data_end)
        return XDP_PASS;  // not enough data

    if (eth->h_proto != bpf_htons(ETH_P_IP))
        return XDP_PASS;

    struct iphdr *iph = data + sizeof(*eth);
    if ((void *)iph + sizeof(*iph) > data_end)
        return XDP_PASS;

    if (iph->protocol != IPPROTO_TCP)
        return XDP_PASS;

    // Check the IP header length. An IHL below the fixed header size would
    // make the TCP header pointer land inside the IP header itself.
    int ip_hdr_len = iph->ihl * 4;
    if (ip_hdr_len < (int)sizeof(*iph))
        return XDP_PASS;
    if ((void *)iph + ip_hdr_len > data_end)
        return XDP_PASS;

    // Locate the TCP header
    struct tcphdr *tcph = (void *)iph + ip_hdr_len;
    if ((void *)tcph + sizeof(*tcph) > data_end)
        return XDP_PASS;

    // Only pure SYNs are answered; every other TCP segment is dropped.
    if (!(tcph->syn) || tcph->ack)
        return XDP_DROP;

    // Derive the TCP segment length before touching the packet, and reject a
    // total length too small to hold both headers (the subtraction would
    // otherwise wrap around to a huge length).
    __u16 ip_tot_len = bpf_ntohs(iph->tot_len);
    if (ip_tot_len < (__u32)ip_hdr_len + sizeof(*tcph))
        return XDP_PASS;
    __u16 tcp_len = ip_tot_len - ip_hdr_len;

    // swap MAC addresses
    __u8 tmp_mac[ETH_ALEN];
    __builtin_memcpy(tmp_mac, eth->h_source, ETH_ALEN);
    __builtin_memcpy(eth->h_source, eth->h_dest, ETH_ALEN);
    __builtin_memcpy(eth->h_dest, tmp_mac, ETH_ALEN);

    // swap IP addresses
    __be32 tmp_ip = iph->saddr;
    iph->saddr = iph->daddr;
    iph->daddr = tmp_ip;

    // swap TCP ports
    __be16 tmpsrcport = tcph->source;
    tcph->source = tcph->dest;
    tcph->dest = tmpsrcport;

    // SYN -> SYN+ACK, acknowledging the peer's sequence number + 1
    tcph->ack = 1;
    __u32 ack_seq = bpf_ntohl(tcph->seq) + 1;
    tcph->ack_seq = bpf_htonl(ack_seq);

    // Recompute the TCP checksum from scratch; the field must be zero while
    // the segment is being summed.
    tcph->check = 0;

    // A 64-bit accumulator is used so that no carry is lost: the helper's
    // result may occupy all 32 bits before the remaining terms are added.
    // Pseudo header: source + destination address (8 contiguous bytes
    // starting at saddr, already covered by the iphdr bounds check above),
    // protocol and TCP length.
    __u64 csum = (__u32)bpf_csum_diff(0, 0, (__be32 *)&iph->saddr, 8, 0);
    csum += bpf_htons(IPPROTO_TCP);
    csum += bpf_htons(tcp_len);

    // TCP header + payload (sum16 stops at data_end if the frame is shorter)
    csum += sum16(tcph, tcp_len, data_end);

    // Fold the carries back into the low 16 bits. The accumulator is below
    // 2^33 here, so three folds are enough; the fourth is a safety margin and
    // a no-op once the value fits in 16 bits.
    #pragma unroll
    for (int i = 0; i < 4; ++i)
        csum = (csum & 0xFFFF) + (csum >> 16);

    tcph->check = (__be16)~csum;

    return XDP_TX;
}
 
/* License string for this object, emitted into the "license" section. */
char _license[] SEC("license") = "GPL";